#include "prism.h"

#include <stdio.h>

/**
 * Narrow a pointer difference to a uint32_t.
 *
 * When assertions are enabled, the value must be non-negative and strictly
 * less than UINT32_MAX. The upper-bound comparison is performed in uintmax_t
 * rather than unsigned long: on targets where unsigned long is 32 bits but
 * ptrdiff_t is 64 bits, casting to unsigned long would truncate the value
 * before the comparison and let out-of-range differences pass the check.
 *
 * @param value The pointer difference to convert.
 * @return The same value as a uint32_t.
 */
static inline uint32_t
pm_ptrdifft_to_u32(ptrdiff_t value) {
    assert(value >= 0 && ((uintmax_t) value) < UINT32_MAX);
    return (uint32_t) value;
}

/**
 * Narrow a size_t to a uint32_t. When assertions are enabled, the value must
 * be strictly less than UINT32_MAX.
 *
 * @param value The size to convert.
 * @return The same value as a uint32_t.
 */
static inline uint32_t
pm_sizet_to_u32(size_t value) {
    assert(value < UINT32_MAX);

    uint32_t narrowed = (uint32_t) value;
    return narrowed;
}

/**
 * Append a location to the buffer as two varuints: the offset of its start
 * from the beginning of the parser's source, followed by its length.
 *
 * The location must have non-NULL start and end pointers with start <= end
 * (checked by assertions).
 */
static void
pm_serialize_location(const pm_parser_t *parser, const pm_location_t *location, pm_buffer_t *buffer) {
    assert(location->start);
    assert(location->end);
    assert(location->start <= location->end);

    ptrdiff_t start_offset = location->start - parser->start;
    ptrdiff_t span = location->end - location->start;

    pm_buffer_append_varuint(buffer, pm_ptrdifft_to_u32(start_offset));
    pm_buffer_append_varuint(buffer, pm_ptrdifft_to_u32(span));
}

/**
 * Append a string to the buffer, prefixed by a one-byte tag:
 *
 * - tag 1 (shared strings): followed by the string's offset from the start of
 *   the parser's source and its length, both as varuints;
 * - tag 2 (owned and constant strings): followed by the length as a varuint
 *   and then the raw bytes of the string.
 *
 * Mapped strings are not supported and trip an assertion.
 */
static void
pm_serialize_string(pm_parser_t *parser, pm_string_t *string, pm_buffer_t *buffer) {
    switch (string->type) {
        case PM_STRING_SHARED: {
            pm_buffer_append_byte(buffer, 1);

            ptrdiff_t source_offset = pm_string_source(string) - parser->start;
            pm_buffer_append_varuint(buffer, pm_ptrdifft_to_u32(source_offset));

            size_t source_length = pm_string_length(string);
            pm_buffer_append_varuint(buffer, pm_sizet_to_u32(source_length));
            break;
        }
        case PM_STRING_OWNED:
        case PM_STRING_CONSTANT: {
            // The length is converted before anything is written.
            uint32_t byte_count = pm_sizet_to_u32(pm_string_length(string));

            pm_buffer_append_byte(buffer, 2);
            pm_buffer_append_varuint(buffer, byte_count);
            pm_buffer_append_bytes(buffer, pm_string_source(string), byte_count);
            break;
        }
        case PM_STRING_MAPPED:
            assert(false && "Cannot serialize mapped strings.");
            break;
    }
}

/**
 * Serialize a single node, and recursively its child nodes, into the buffer.
 *
 * The bytes written are, in order: one byte holding the node type, the node's
 * location (see pm_serialize_location), and then the type-specific fields
 * emitted by the template loop below, following the order of node.fields.
 *
 * For node types where needs_serialized_length? is true, a 4-byte slot is
 * reserved directly after the location and later overwritten with the number
 * of bytes written since `offset` (location plus fields), not counting the
 * slot itself. That value is copied with memcpy, i.e. in host byte order.
 *
 * @param parser The parser whose source start is used as the base for offsets.
 * @param node The node to serialize.
 * @param buffer The buffer to append to.
 */
static void
pm_serialize_node(pm_parser_t *parser, pm_node_t *node, pm_buffer_t *buffer) {
    pm_buffer_append_byte(buffer, (uint8_t) PM_NODE_TYPE(node));

    // Position just after the type byte; used below as the base when
    // computing the serialized length.
    size_t offset = buffer->length;

    pm_serialize_location(parser, &node->location, buffer);

    switch (PM_NODE_TYPE(node)) {
        // We do not need to serialize a ScopeNode ever as
        // it is not part of the AST. Note that the type byte and location
        // above have already been appended by the time we return here.
        case PM_SCOPE_NODE:
            return;
        <%- nodes.each do |node| -%>
        case <%= node.type %>: {
            <%- if node.needs_serialized_length? -%>
            // Reserve a 4-byte slot for this node's serialized length. The
            // location above is written with pm_buffer_append_varuint, so the
            // position of the slot is saved here instead of being derived.
            size_t length_offset = buffer->length;
            pm_buffer_append_string(buffer, "\0\0\0\0", 4); /* consume 4 bytes, updated below */
            <%- end -%>
            <%- node.fields.each do |field| -%>
            <%- case field -%>
            <%- when Prism::NodeField -%>
            pm_serialize_node(parser, (pm_node_t *)((pm_<%= node.human %>_t *)node)-><%= field.name %>, buffer);
            <%- when Prism::OptionalNodeField -%>
            if (((pm_<%= node.human %>_t *)node)-><%= field.name %> == NULL) {
                pm_buffer_append_byte(buffer, 0);
            } else {
                pm_serialize_node(parser, (pm_node_t *)((pm_<%= node.human %>_t *)node)-><%= field.name %>, buffer);
            }
            <%- when Prism::StringField -%>
            pm_serialize_string(parser, &((pm_<%= node.human %>_t *)node)-><%= field.name %>, buffer);
            <%- when Prism::NodeListField -%>
            uint32_t <%= field.name %>_size = pm_sizet_to_u32(((pm_<%= node.human %>_t *)node)-><%= field.name %>.size);
            pm_buffer_append_varuint(buffer, <%= field.name %>_size);
            for (uint32_t index = 0; index < <%= field.name %>_size; index++) {
                pm_serialize_node(parser, (pm_node_t *) ((pm_<%= node.human %>_t *)node)-><%= field.name %>.nodes[index], buffer);
            }
            <%- when Prism::ConstantField, Prism::OptionalConstantField -%>
            pm_buffer_append_varuint(buffer, pm_sizet_to_u32(((pm_<%= node.human %>_t *)node)-><%= field.name %>));
            <%- when Prism::ConstantListField -%>
            uint32_t <%= field.name %>_size = pm_sizet_to_u32(((pm_<%= node.human %>_t *)node)-><%= field.name %>.size);
            pm_buffer_append_varuint(buffer, <%= field.name %>_size);
            for (uint32_t index = 0; index < <%= field.name %>_size; index++) {
                pm_buffer_append_varuint(buffer, pm_sizet_to_u32(((pm_<%= node.human %>_t *)node)-><%= field.name %>.ids[index]));
            }
            <%- when Prism::LocationField -%>
            <%- if field.should_be_serialized? -%>
            pm_serialize_location(parser, &((pm_<%= node.human %>_t *)node)-><%= field.name %>, buffer);
            <%- end -%>
            <%- when Prism::OptionalLocationField -%>
            <%- if field.should_be_serialized? -%>
            if (((pm_<%= node.human %>_t *)node)-><%= field.name %>.start == NULL) {
                pm_buffer_append_byte(buffer, 0);
            } else {
                pm_buffer_append_byte(buffer, 1);
                pm_serialize_location(parser, &((pm_<%= node.human %>_t *)node)-><%= field.name %>, buffer);
            }
            <%- end -%>
            <%- when Prism::UInt8Field -%>
            pm_buffer_append_byte(buffer, ((pm_<%= node.human %>_t *)node)-><%= field.name %>);
            <%- when Prism::UInt32Field -%>
            pm_buffer_append_varuint(buffer, ((pm_<%= node.human %>_t *)node)-><%= field.name %>);
            <%- when Prism::FlagsField -%>
            pm_buffer_append_varuint(buffer, (uint32_t)(node->flags & ~PM_NODE_FLAG_COMMON_MASK));
            <%- else -%>
            <%- raise -%>
            <%- end -%>
            <%- end -%>
            <%- if node.needs_serialized_length? -%>
            // Backfill the reserved slot with the number of bytes written
            // since `offset`, excluding the 4-byte slot itself.
            uint32_t length = pm_sizet_to_u32(buffer->length - offset - sizeof(uint32_t));
            memcpy(buffer->value + length_offset, &length, sizeof(uint32_t));
            <%- end -%>
            break;
        }
        <%- end -%>
    }
}

/**
 * Append a single comment to the buffer: one byte for the comment type,
 * followed by the comment's location.
 */
static void
pm_serialize_comment(pm_parser_t *parser, pm_comment_t *comment, pm_buffer_t *buffer) {
    uint8_t type_tag = (uint8_t) comment->type;
    pm_buffer_append_byte(buffer, type_tag);

    pm_serialize_location(parser, &comment->location, buffer);
}

/**
 * Serialize the given list of comments to the given buffer. The number of
 * comments is written first as a varuint, followed by each comment in list
 * order.
 */
void
pm_serialize_comment_list(pm_parser_t *parser, pm_list_t *list, pm_buffer_t *buffer) {
    pm_buffer_append_varuint(buffer, pm_sizet_to_u32(pm_list_size(list)));

    pm_comment_t *cursor = (pm_comment_t *) list->head;
    while (cursor != NULL) {
        pm_serialize_comment(parser, cursor, buffer);
        cursor = (pm_comment_t *) cursor->node.next;
    }
}

/**
 * Append a magic comment to the buffer as four varuints: the key's offset
 * from the start of the source, the key's length, the value's offset from the
 * start of the source, and the value's length.
 */
static void
pm_serialize_magic_comment(pm_parser_t *parser, pm_magic_comment_t *magic_comment, pm_buffer_t *buffer) {
    ptrdiff_t key_offset = magic_comment->key_start - parser->start;
    ptrdiff_t value_offset = magic_comment->value_start - parser->start;

    // Key: offset then length.
    pm_buffer_append_varuint(buffer, pm_ptrdifft_to_u32(key_offset));
    pm_buffer_append_varuint(buffer, pm_sizet_to_u32(magic_comment->key_length));

    // Value: offset then length.
    pm_buffer_append_varuint(buffer, pm_ptrdifft_to_u32(value_offset));
    pm_buffer_append_varuint(buffer, pm_sizet_to_u32(magic_comment->value_length));
}

/**
 * Append a list of magic comments to the buffer: the number of entries as a
 * varuint, followed by each magic comment in list order.
 */
static void
pm_serialize_magic_comment_list(pm_parser_t *parser, pm_list_t *list, pm_buffer_t *buffer) {
    pm_buffer_append_varuint(buffer, pm_sizet_to_u32(pm_list_size(list)));

    pm_magic_comment_t *cursor = (pm_magic_comment_t *) list->head;
    while (cursor != NULL) {
        pm_serialize_magic_comment(parser, cursor, buffer);
        cursor = (pm_magic_comment_t *) cursor->node.next;
    }
}

/**
 * Append the parser's data_loc to the buffer. A single 0 byte is written when
 * data_loc.end is NULL; otherwise a 1 byte is written followed by the
 * location itself.
 */
static void
pm_serialize_data_loc(const pm_parser_t *parser, pm_buffer_t *buffer) {
    if (parser->data_loc.end != NULL) {
        pm_buffer_append_byte(buffer, 1);
        pm_serialize_location(parser, &parser->data_loc, buffer);
        return;
    }

    pm_buffer_append_byte(buffer, 0);
}

/**
 * Append a diagnostic to the buffer: the message length as a varuint, the
 * message bytes (without a terminating NUL), the diagnostic's location, and
 * finally one byte holding its level.
 */
static void
pm_serialize_diagnostic(pm_parser_t *parser, pm_diagnostic_t *diagnostic, pm_buffer_t *buffer) {
    const char *text = diagnostic->message;
    size_t text_length = strlen(text);

    // Length-prefixed message.
    pm_buffer_append_varuint(buffer, pm_sizet_to_u32(text_length));
    pm_buffer_append_string(buffer, text, text_length);

    // Where the diagnostic applies, then its level.
    pm_serialize_location(parser, &diagnostic->location, buffer);
    pm_buffer_append_byte(buffer, diagnostic->level);
}

/**
 * Append a list of diagnostics to the buffer: the number of entries as a
 * varuint, followed by each diagnostic in list order.
 */
static void
pm_serialize_diagnostic_list(pm_parser_t *parser, pm_list_t *list, pm_buffer_t *buffer) {
    pm_buffer_append_varuint(buffer, pm_sizet_to_u32(pm_list_size(list)));

    pm_diagnostic_t *cursor = (pm_diagnostic_t *) list->head;
    while (cursor != NULL) {
        pm_serialize_diagnostic(parser, cursor, buffer);
        cursor = (pm_diagnostic_t *) cursor->node.next;
    }
}

/**
 * Serialize the name of the encoding to the buffer: the name's length as a
 * varuint, followed by the name's bytes (without a terminating NUL).
 */
void
pm_serialize_encoding(const pm_encoding_t *encoding, pm_buffer_t *buffer) {
    const char *name = encoding->name;
    size_t name_length = strlen(name);

    pm_buffer_append_varuint(buffer, pm_sizet_to_u32(name_length));
    pm_buffer_append_string(buffer, name, name_length);
}

#line <%= __LINE__ + 1 %> "<%= File.basename(__FILE__) %>"
/**
 * Serialize the encoding, metadata, nodes, and constant pool.
 *
 * The bytes are appended in this order: encoding name, start line, comments
 * (unless the template is built with SERIALIZE_ONLY_SEMANTICS_FIELDS), magic
 * comments, data location, errors, warnings, a 4-byte slot that is later
 * filled with the buffer offset of the constant pool, the constant pool size,
 * the node tree, the constant pool table (8 bytes per constant), and finally
 * the contents of owned/constant constants.
 *
 * Each constant pool entry is two 4-byte values copied with memcpy (host byte
 * order): an offset and a length. For shared constants the offset is relative
 * to the start of the source. For owned/constant constants the offset points
 * into this buffer and has its top bit set to mark it as such.
 *
 * @param parser The parser holding the metadata and constant pool.
 * @param node The root node to serialize.
 * @param buffer The buffer to append to.
 */
void
pm_serialize_content(pm_parser_t *parser, pm_node_t *node, pm_buffer_t *buffer) {
    pm_serialize_encoding(parser->encoding, buffer);
    pm_buffer_append_varsint(buffer, parser->start_line);
<%- unless Prism::SERIALIZE_ONLY_SEMANTICS_FIELDS -%>
    pm_serialize_comment_list(parser, &parser->comment_list, buffer);
<%- end -%>
    pm_serialize_magic_comment_list(parser, &parser->magic_comment_list, buffer);
    pm_serialize_data_loc(parser, buffer);
    pm_serialize_diagnostic_list(parser, &parser->error_list, buffer);
    pm_serialize_diagnostic_list(parser, &parser->warning_list, buffer);

    // Here we're going to leave space for the offset of the constant pool in
    // the buffer.
    size_t offset = buffer->length;
    pm_buffer_append_zeroes(buffer, 4);

    // Next, encode the length of the constant pool.
    pm_buffer_append_varuint(buffer, parser->constant_pool.size);

    // Now we're going to serialize the content of the node.
    pm_serialize_node(parser, node, buffer);

    // Now we're going to serialize the offset of the constant pool back where
    // we left space for it.
    uint32_t length = pm_sizet_to_u32(buffer->length);
    memcpy(buffer->value + offset, &length, sizeof(uint32_t));

    // Now we're going to serialize the constant pool. Reserve 8 zeroed bytes
    // per constant; the multiplication is done in size_t so that it cannot
    // wrap in a narrower type before being widened.
    offset = buffer->length;
    pm_buffer_append_zeroes(buffer, ((size_t) parser->constant_pool.size) * 8);

    for (uint32_t index = 0; index < parser->constant_pool.capacity; index++) {
        pm_constant_pool_bucket_t *bucket = &parser->constant_pool.buckets[index];

        // If we find a constant at this index, serialize it at the correct
        // index in the buffer.
        if (bucket->id != 0) {
            pm_constant_t *constant = &parser->constant_pool.constants[bucket->id - 1];
            size_t buffer_offset = offset + ((((size_t)bucket->id) - 1) * 8);

            if (bucket->type == PM_CONSTANT_POOL_BUCKET_OWNED || bucket->type == PM_CONSTANT_POOL_BUCKET_CONSTANT) {
                // Since this is an owned or constant constant, we are going to
                // write its contents into the buffer after the constant pool.
                // So effectively in place of the source offset, we have a
                // buffer offset. We will add a leading 1 to indicate that this
                // is a buffer offset.
                uint32_t content_offset = pm_sizet_to_u32(buffer->length);

                // Build the mask from an unsigned 32-bit operand: shifting a
                // signed int 1 left by 31 would shift into the sign bit, which
                // is undefined behavior.
                uint32_t owned_mask = UINT32_C(1) << 31;

                assert(content_offset < owned_mask);
                content_offset |= owned_mask;

                memcpy(buffer->value + buffer_offset, &content_offset, 4);
                pm_buffer_append_bytes(buffer, constant->start, constant->length);
            } else {
                // Since this is a shared constant, we are going to write its
                // source offset directly into the buffer.
                uint32_t source_offset = pm_ptrdifft_to_u32(constant->start - parser->start);
                memcpy(buffer->value + buffer_offset, &source_offset, 4);
            }

            // Now we can write the length of the constant into the buffer.
            uint32_t constant_length = pm_sizet_to_u32(constant->length);
            memcpy(buffer->value + buffer_offset + 4, &constant_length, 4);
        }
    }
}

/**
 * Lex callback that appends one token to the buffer passed through `data`.
 * Four varuints are written: the token type, the token's offset from the
 * start of the source, the token's length, and the parser's current lex
 * state.
 */
static void
serialize_token(void *data, pm_parser_t *parser, pm_token_t *token) {
    pm_buffer_t *out = data;

    ptrdiff_t token_offset = token->start - parser->start;
    ptrdiff_t token_length = token->end - token->start;

    pm_buffer_append_varuint(out, token->type);
    pm_buffer_append_varuint(out, pm_ptrdifft_to_u32(token_offset));
    pm_buffer_append_varuint(out, pm_ptrdifft_to_u32(token_length));
    pm_buffer_append_varuint(out, parser->lex_state);
}

/**
 * Lex the given source and serialize to the given buffer.
 *
 * Tokens are appended through the lex callback while parsing. They are
 * followed by a 0 byte marking the end of the tokens, and then the encoding,
 * start line, comments, magic comments, data location, errors, and warnings.
 */
PRISM_EXPORTED_FUNCTION void
pm_serialize_lex(pm_buffer_t *buffer, const uint8_t *source, size_t size, const char *data) {
    pm_options_t options = { 0 };
    pm_options_read(&options, data);

    pm_parser_t parser;
    pm_parser_init(&parser, source, size, &options);

    // Route every lexed token into the output buffer.
    pm_lex_callback_t token_sink = { .data = buffer, .callback = serialize_token };
    parser.lex_callback = &token_sink;

    pm_node_t *root = pm_parse(&parser);

    // Append 0 to mark end of tokens.
    pm_buffer_append_byte(buffer, 0);

    // Metadata gathered while parsing.
    pm_serialize_encoding(parser.encoding, buffer);
    pm_buffer_append_varsint(buffer, parser.start_line);
    pm_serialize_comment_list(&parser, &parser.comment_list, buffer);
    pm_serialize_magic_comment_list(&parser, &parser.magic_comment_list, buffer);
    pm_serialize_data_loc(&parser, buffer);
    pm_serialize_diagnostic_list(&parser, &parser.error_list, buffer);
    pm_serialize_diagnostic_list(&parser, &parser.warning_list, buffer);

    // Release everything acquired above.
    pm_node_destroy(&parser, root);
    pm_parser_free(&parser);
    pm_options_free(&options);
}

/**
 * Parse and serialize both the AST and the tokens represented by the given
 * source to the given buffer.
 *
 * Tokens are appended through the lex callback while parsing, followed by a
 * 0 byte, followed by the output of pm_serialize for the parsed tree.
 */
PRISM_EXPORTED_FUNCTION void
pm_serialize_parse_lex(pm_buffer_t *buffer, const uint8_t *source, size_t size, const char *data) {
    pm_options_t options = { 0 };
    pm_options_read(&options, data);

    pm_parser_t parser;
    pm_parser_init(&parser, source, size, &options);

    // Route every lexed token into the output buffer.
    pm_lex_callback_t token_sink = { .data = buffer, .callback = serialize_token };
    parser.lex_callback = &token_sink;

    pm_node_t *root = pm_parse(&parser);

    // Terminate the token stream, then append the serialized parse result.
    pm_buffer_append_byte(buffer, 0);
    pm_serialize(&parser, root, buffer);

    // Release everything acquired above.
    pm_node_destroy(&parser, root);
    pm_parser_free(&parser);
    pm_options_free(&options);
}

/**
 * Parse the source and return true if it parses without errors or warnings.
 *
 * The parsed tree is discarded; only the sizes of the parser's error and
 * warning lists are inspected.
 */
PRISM_EXPORTED_FUNCTION bool
pm_parse_success_p(const uint8_t *source, size_t size, const char *data) {
    pm_options_t options = { 0 };
    pm_options_read(&options, data);

    pm_parser_t parser;
    pm_parser_init(&parser, source, size, &options);

    pm_node_t *root = pm_parse(&parser);
    pm_node_destroy(&parser, root);

    // Read the diagnostic counts before the parser is freed.
    bool has_diagnostics = parser.error_list.size != 0 || parser.warning_list.size != 0;

    pm_parser_free(&parser);
    pm_options_free(&options);

    return !has_diagnostics;
}
